import copy
import numpy as np
import time
import mindspore as ms
import mindspore.numpy as msnp
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import dataset as ds
from mindspore.train import Model
from mindspore import context
from mindspore.common.initializer import initializer
from mindspore import ms_function


class AlphaChemlLossCell(nn.Cell):
    """Training loss for AlphaChem: contrastive term plus shooting likelihood.

    `coor_net` maps atomic coordinates [B,A,3] to a scalar reaction
    coordinate r of shape [B,1]; `prob_net` maps a phase-space point (r, v)
    of shape [B,2] to a logit, converted to a probability via 0.5*erfc(-x).

    Args:
        coor_net (nn.Cell):      Reaction-coordinate network, [B,A,3] -> [B,1].
        prob_net (nn.Cell):      Probability network, [B,2] -> [B,1] logits.
        batch_size (int):        Batch size B.
        num_contrast (int):      Number of contrastive samples C per batch item.
        weight_positive (float): Weight of the positive-sample likelihood term.
        eps (float):             Lower clip bound used inside logarithms.
    """

    def __init__(self,
        coor_net,
        prob_net,
        batch_size,
        num_contrast,
        weight_positive=0.4,
        eps=1e-4,
    ):
        super().__init__(auto_prefix=False)
        self.coor_net = coor_net
        self.prob_net = prob_net

        self.batch_size = batch_size
        self.num_contrast = num_contrast

        self.weight_positive = weight_positive
        self.eps = eps

        self.log = P.Log()
        self.erfc = P.Erfc()
        self.concat = P.Concat(axis=-1)
        self.keep_sum = P.ReduceSum(keep_dims=True)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.keep_mean = P.ReduceMean(keep_dims=True)
        self.reduce_mean = P.ReduceMean(keep_dims=False)

        self.grad = C.GradOperation()

        self.scalar_summary = P.ScalarSummary()
        self.tensor_summary = P.TensorSummary()

    def _calc_phase_space(self, R, V):
        """Build forward/backward phase-space points (r, +v) and (r, -v).

        (Renamed from the original misspelled `_calc_phase_spcae`; the method
        is private and only called from this class.)
        """
        # [B,A,3] -> [B,1]: scalar reaction coordinate.
        r = self.coor_net(R)
        # [B,A,3]: gradient of the reaction coordinate w.r.t. coordinates.
        dr = self.grad(self.coor_net)(R)

        # Flatten per-atom dimensions: [B,A,3] -> [B,A*3].
        dotshape = (R.shape[0], -1)
        # [B,1]: velocity projected onto the reaction-coordinate gradient.
        v = C.batch_dot(dr.reshape(dotshape), V.reshape(dotshape))

        # [B,2] each: forward shooting uses +v, backward uses -v.
        forward = self.concat((r, v))
        backward = self.concat((r, -v))

        return forward, backward

    def _calc_trans_path_prob(self, R, V, Mf, Mb):
        """Compute transition-path probabilities and the shooting likelihood.

        Returns:
            ptp      ([B,1]):  PTP of the first contrastive sample only.
            ptp_sum  ([B,1]):  PTP averaged over the C contrastive samples.
            likeli   (scalar): Mean negative log-likelihood of the shooting
                               outcomes (forward and backward averaged).
        """
        # [B*C,2]
        qf, qb = self._calc_phase_space(R, V)

        # [B*C,1]: commitment probabilities via 0.5*erfc(-logit).
        pf = 0.5 * self.erfc(-self.prob_net(qf))
        pb = 0.5 * self.erfc(-self.prob_net(qb))

        # [B*C,1]: combined forward/backward probability
        # pf*(1-pb) + (1-pf)*pb.
        ptp = pf * (1 - pb) + (1.0 - pf) * pb
        # [B,C,1]
        ptp = ptp.reshape(-1, self.num_contrast, 1)
        # [B,1]: mean over the contrastive axis.
        ptp_sum = self.reduce_mean(ptp, -2)
        # [B,1]: keep only the first contrastive sample.
        ptp = ptp[:, 0, :]

        # [B,C,1] -> [B,1]: likelihood also uses only the first sample.
        pf = pf.reshape(-1, self.num_contrast, 1)
        pb = pb.reshape(-1, self.num_contrast, 1)
        pf = pf[:, 0, :]
        pb = pb[:, 0, :]

        # Forward branch: -log(1-pf) where Mf is True, else -log(pf);
        # clip keeps the log finite.
        log_1_pf = -self.log(C.clip_by_value(1 - pf, self.eps, 1))
        log_pf = -self.log(C.clip_by_value(pf, self.eps, 1))
        likeli_forw = F.select(Mf, log_1_pf, log_pf)
        likeli_forw = self.reduce_mean(likeli_forw)

        # Backward branch, analogous with mask Mb.
        log_1_pb = -self.log(C.clip_by_value(1 - pb, self.eps, 1))
        log_pb = -self.log(C.clip_by_value(pb, self.eps, 1))
        likeli_back = F.select(Mb, log_1_pb, log_pb)
        likeli_back = self.reduce_mean(likeli_back)

        likeli = (likeli_forw + likeli_back) / 2

        return ptp, ptp_sum, likeli

    def construct(self, Rp, Vp, Mp, Rn, Vn, Mn):
        """Compute the loss function of the AlphaChem.

        Args:
            Rp     (mindspore.Tensor[float], [B*C, A, 3]):    Cartesian coordinates for each atom of positive samples.
            Vp     (mindspore.Tensor[float], [B*C, A, 3]):    Velocities for each atom of positive sample.
            Mp     (mindspore.Tensor[bool],  [B*C, 1]):       Mask for shooting results (True for A and False for B) of positive sample.
            Rn     (mindspore.Tensor[float], [B*C, A, 3]):    Cartesian coordinates for each atom of negative sample.
            Vn     (mindspore.Tensor[float], [B*C, A, 3]):    Velocities for each atom of negative sample.
            Mn     (mindspore.Tensor[bool],  [B*C, 1]):       Mask for shooting results (True for A and False for B) of negative sample.

            B:  Batch size
            C:  Number of contrastive samples
            A:  Number of input atoms

        Returns:
            loss (mindspore.Tensor[float], scalar): weighted sum of the
            contrastive loss and the shooting-likelihood loss. (The original
            docstring claimed shape [B,1], but both terms are reduced by
            `reduce_mean` before being combined.)
        """

        # [B*C,1] -> [B,C,1] -> [B,1]: only the first contrastive sample's
        # mask is used for the likelihood term.
        Mp = Mp.reshape(-1, self.num_contrast, 1)
        Mn = Mn.reshape(-1, self.num_contrast, 1)
        Mp = Mp[:, 0, :]
        Mn = Mn[:, 0, :]

        # [B,1], [B,1], scalar.  Positive samples: backward mask is the
        # negation of the forward mask; negative samples: both ends share Mn.
        ptp_pos, sum_pos, likeli_pos = self._calc_trans_path_prob(Rp, Vp, Mp, F.logical_not(Mp))
        ptp_neg, sum_neg, likeli_neg = self._calc_trans_path_prob(Rn, Vn, Mn, Mn)

        # [B,1]: InfoNCE-style ratios, clipped before the log.
        loss_pos = -self.log(C.clip_by_value(ptp_pos / (ptp_pos + sum_neg), self.eps, 1))
        loss_neg = -self.log(C.clip_by_value((1. - ptp_neg) / ((1. - ptp_neg) + (1. - sum_pos)), self.eps, 1))

        loss_pos = self.reduce_mean(loss_pos)
        loss_neg = self.reduce_mean(loss_neg)

        loss_contrast = 0.5 * (loss_pos + loss_neg)

        loss_likelihood = self.weight_positive * likeli_pos + likeli_neg

        self.scalar_summary('loss_contrast', loss_contrast)
        self.scalar_summary('loss_likelihood', loss_likelihood)

        loss = loss_contrast + loss_likelihood * 0.2

        return loss



class AlphaChemlLossCellWithMemorySet(nn.Cell):
    """AlphaChem loss with momentum (EMA) memory networks and memory sets.

    Keeps deep copies of the coordinate/probability networks ("memory"
    networks), updated as an exponential moving average of the trained
    networks, and maintains per-contrast rolling buffers ([C,B,1]) of
    transition-path probabilities used as contrastive denominators.

    Args:
        coor_net (nn.Cell):      Reaction-coordinate network, [B,A,3] -> [B,1].
        prob_net (nn.Cell):      Probability network, [B,2] -> [B,1] logits.
        batch_size (int):        Batch size B.
        num_contrast (int):      Number of contrastive samples C.
        weight_positive (float): Weight of the positive-sample likelihood term.
        eps (float):             Lower clip bound used inside logarithms.
        momentum_decay (float):  EMA rate d; memory <- (1-d)*memory + d*net.
        R_pos_init, V_pos_init (Tensor, optional): [B*C,A,3] tensors used to
            pre-fill the positive memory set.
        R_neg_init, V_neg_init (Tensor, optional): [B*C,A,3] tensors used to
            pre-fill the negative memory set.
    """

    def __init__(self,
        coor_net,
        prob_net,
        batch_size,
        num_contrast,
        weight_positive=0.4,
        eps=1e-4,
        momentum_decay=1e-3,
        R_pos_init=None,
        V_pos_init=None,
        R_neg_init=None,
        V_neg_init=None,
    ):
        super().__init__(auto_prefix=False)
        self.coor_net = coor_net
        self.prob_net = prob_net

        self.momentum_decay = momentum_decay
        self.momentum_coef = 1.0 - momentum_decay

        # Bug fix: these must be set BEFORE the optional _calc_memset()
        # calls below — _calc_memset reads self.num_contrast and (via
        # _calc_prob) self.batch_size, which the original assigned only
        # afterwards, raising AttributeError whenever an init set was given.
        self.batch_size = batch_size
        self.num_contrast = num_contrast
        self.weight_positive = weight_positive
        self.eps = eps

        # EMA ("memory") copies of the two networks.
        self.coor_mem = copy.deepcopy(coor_net)
        self.prob_mem = copy.deepcopy(prob_net)

        # Prefix memory parameter names to avoid clashes with the live
        # networks and mirror their trainability flags.
        for m, n in zip(self.coor_mem.get_parameters(), self.coor_net.get_parameters()):
            m.name = 'memset.' + m.name
            m.requires_grad = n.requires_grad

        for m, n in zip(self.prob_mem.get_parameters(), self.prob_net.get_parameters()):
            m.name = 'memset.' + m.name
            m.requires_grad = n.requires_grad

        self.network_params = self.coor_net.trainable_params() + self.prob_net.trainable_params()
        self.memset_params = self.coor_mem.trainable_params() + self.prob_mem.trainable_params()

        # Memory parameters are updated by EMA only, never by gradients.
        for p in self.memset_params:
            p.requires_grad = False

        # [C,B,1] rolling buffers of transition-path probabilities.
        self.memset_pos = ms.Parameter(msnp.ones((num_contrast, batch_size, 1), ms.float32), name='memset_pos', requires_grad=False)
        self.memset_neg = ms.Parameter(msnp.zeros((num_contrast, batch_size, 1), ms.float32), name='memset_neg', requires_grad=False)

        self.assign = P.Assign()
        self.grad = C.GradOperation()
        self.concat = P.Concat(axis=-1)
        self.erfc = P.Erfc()

        # Optionally pre-fill the memory sets from initial shooting points.
        if R_pos_init is not None and V_pos_init is not None:
            memset_pos = self._calc_memset(R_pos_init, V_pos_init)
            self.assign(self.memset_pos, memset_pos)

        if R_neg_init is not None and V_neg_init is not None:
            memset_neg = self._calc_memset(R_neg_init, V_neg_init)
            self.assign(self.memset_neg, memset_neg)

        self.log = P.Log()
        self.keep_sum = P.ReduceSum(keep_dims=True)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.keep_mean = P.ReduceMean(keep_dims=True)
        self.reduce_mean = P.ReduceMean(keep_dims=False)

        self.scalar_summary = P.ScalarSummary()
        self.tensor_summary = P.TensorSummary()

        # Rolling write index into the memory sets (modulo num_contrast).
        self.count = ms.Parameter(Tensor(0, ms.int32), name='count', requires_grad=False)
        self.one = Tensor(1, ms.float32)

        self.print = P.Print()
        self.mod = P.Mod()

    def _update_memset_networks(self):
        """EMA-update the memory parameters: m <- (1-d)*m + d*n.

        Bug fix: the original `m = self.momentum_coef * m + ...` only rebound
        the Python loop variable and never wrote the result back, so the
        memory networks stayed frozen at their initial copy.  Assign performs
        the in-place parameter update.
        """
        for n, m in zip(self.network_params, self.memset_params):
            self.assign(m, self.momentum_coef * m + self.momentum_decay * n)
        return self

    def _calc_memset(self, R, V):
        """Compute a full [C,B,1] memory set from [B*C,A,3] samples."""
        # [B*C,A,3] -> [C,B,A,3]
        shape = (self.num_contrast, self.batch_size) + R.shape[-2:]
        R = F.reshape(R, shape)
        V = F.reshape(V, shape)
        memset = []
        for p, v in zip(R, V):
            # [B,1] per contrastive slot.
            memset.append(self.calc_ptp_mem(p, v))
        # [C,B,1]
        return F.stack(memset)

    def _calc_prob(self, r, dr, V, net):
        """Forward/backward commitment probabilities from (r, dr, V)."""
        # [B,1]: velocity projected onto the reaction-coordinate gradient.
        v = C.batch_dot(dr.reshape(self.batch_size, -1), V.reshape(self.batch_size, -1))

        # [B,2]: forward uses +v, backward uses -v.
        qf = self.concat((r, v))
        qb = self.concat((r, -v))

        # [B,1]: 0.5*erfc(-logit) maps logits into (0,1).
        pf = 0.5 * self.erfc(-net(qf))
        pb = 0.5 * self.erfc(-net(qb))

        return pf, pb

    def _calc_likeli(self, pf, pb, Mf, Mb):
        """Mean negative log-likelihood of shooting outcomes (fwd+bwd)/2."""
        # -log(1-pf) where Mf is True, else -log(pf); clip keeps log finite.
        log_1_pf = -self.log(C.clip_by_value(1 - pf, self.eps, 1))
        log_pf = -self.log(C.clip_by_value(pf, self.eps, 1))
        likeli_forw = F.select(Mf, log_1_pf, log_pf)
        likeli_forw = self.reduce_mean(likeli_forw)

        log_1_pb = -self.log(C.clip_by_value(1 - pb, self.eps, 1))
        log_pb = -self.log(C.clip_by_value(pb, self.eps, 1))
        likeli_back = F.select(Mb, log_1_pb, log_pb)
        likeli_back = self.reduce_mean(likeli_back)

        return (likeli_forw + likeli_back) / 2

    @ms_function
    def _calc_out_and_grad(self, net, R):
        """Return net(R) and its gradient w.r.t. R (graph-compiled helper)."""
        # [B,A,3] -> [B,1]
        r = net(R)
        # [B,A,3]
        dr = self.grad(net)(R)
        return r, dr

    def calc_ptp(self, R, V, Mf, Mb):
        """PTP and likelihood using the live (trained) networks."""
        # [B,A,3] -> [B,1]
        r = self.coor_net(R)
        # [B,A,3]
        dr = self.grad(self.coor_net)(R)

        pf, pb = self._calc_prob(r, dr, V, self.prob_net)
        # [B,1]: pf*(1-pb) + (1-pf)*pb.
        ptp = pf * (1 - pb) + (1.0 - pf) * pb

        likeli = self._calc_likeli(pf, pb, Mf, Mb)

        return ptp, likeli

    def calc_ptp_mem(self, R, V):
        """PTP using the memory coordinate network.

        NOTE(review): this mixes the memory coordinate net (coor_mem) with
        the live probability net (prob_net) rather than prob_mem — confirm
        whether that asymmetry is intentional.
        """
        # [B,A,3] -> [B,1]
        r = self.coor_mem(R)
        # [B,A,3]
        dr = self.grad(self.coor_mem)(R)

        pf, pb = self._calc_prob(r, dr, V, self.prob_net)
        # [B,1]
        ptp = pf * (1 - pb) + (1.0 - pf) * pb

        return ptp

    def construct(self, Rp, Vp, Mp, Ip, Rn, Vn, Mn, In):
        """Compute the loss function of the AlphaChem.

        Args:
            Rp     (mindspore.Tensor[float], [B*C, A, 3]):    Cartesian coordinates for each atom of positive samples.
            Vp     (mindspore.Tensor[float], [B*C, A, 3]):    Velocities for each atom of positive sample.
            Mp     (mindspore.Tensor[bool],  [B*C, 1]):       Mask for shooting results (True for A and False for B) of positive sample.
            Ip     (mindspore.Tensor[int]):                   Indices used to gather the averaged positive memory set (presumably batch indices — TODO confirm against the data pipeline).
            Rn     (mindspore.Tensor[float], [B*C, A, 3]):    Cartesian coordinates for each atom of negative sample.
            Vn     (mindspore.Tensor[float], [B*C, A, 3]):    Velocities for each atom of negative sample.
            Mn     (mindspore.Tensor[bool],  [B*C, 1]):       Mask for shooting results (True for A and False for B) of negative sample.
            In     (mindspore.Tensor[int]):                   Indices used to gather the averaged negative memory set.

            B:  Batch size
            C:  Number of contrastive samples
            A:  Number of input atoms

        Returns:
            loss (mindspore.Tensor[float], scalar): weighted sum of the
            contrastive loss and the shooting-likelihood loss.
        """

        # Advance the rolling write index modulo C.  Assign returns the
        # stored value; the `+= 0` appears to be a graph-mode trick to force
        # the assignment to execute before `count` is used as an index.
        count = self.mod(self.count + 1, self.num_contrast)
        count = self.assign(self.count, count)
        count += 0

        # EMA update of the memory networks.
        self._update_memset_networks()

        # [B,1], scalar.
        ptp_pos, likeli_pos = self.calc_ptp(Rp, Vp, Mp, F.logical_not(Mp))
        ptp_neg, likeli_neg = self.calc_ptp(Rn, Vn, Mn, Mn)

        # [B,1]: memory-network PTPs for the current batch.
        mem_pos = self.calc_ptp_mem(Rp, Vp)
        mem_neg = self.calc_ptp_mem(Rn, Vn)

        # Overwrite one slot of each [C,B,1] memory set.
        self.memset_pos[count] = mem_pos
        self.memset_neg[count] = mem_neg

        # [C,B,1] -> [B,1]: average over the contrastive axis.
        sum_pos = self.reduce_mean(self.memset_pos, 0)
        sum_neg = self.reduce_mean(self.memset_neg, 0)

        sum_pos = sum_pos[Ip]
        sum_neg = sum_neg[In]

        # [B,1]: InfoNCE-style ratios, clipped before the log.
        loss_pos = -self.log(C.clip_by_value(ptp_pos / (ptp_pos + sum_neg), self.eps, 1))
        loss_neg = -self.log(C.clip_by_value((1. - ptp_neg) / ((1. - ptp_neg) + (1. - sum_pos)), self.eps, 1))

        loss_pos = self.reduce_mean(loss_pos)
        loss_neg = self.reduce_mean(loss_neg)

        loss_contrast = 0.5 * (loss_pos + loss_neg)

        loss_likelihood = self.weight_positive * likeli_pos + likeli_neg

        self.scalar_summary('loss_contrast', loss_contrast)
        self.scalar_summary('loss_likelihood', loss_likelihood)

        loss = loss_contrast + loss_likelihood * 0.2

        return loss

class TransitionPathProbability(nn.Cell):
    """Transition-path probability for a batch of shooting points.

    Args:
        coor_net (nn.Cell): Reaction-coordinate network, [B,A,3] -> [B,1].
        prob_net (nn.Cell): Probability network, [B,2] -> [B,1] logits.
        return_all_probs (bool): If True, construct also returns the forward
            and backward commitment probabilities.
    """

    def __init__(self,
        coor_net,
        prob_net,
        return_all_probs=True,
    ):
        super().__init__()

        self.coor_net = coor_net
        self.prob_net = prob_net

        # Stored under its historical name so external code that toggles the
        # flag (e.g. `net.return_forward_prob = False`) keeps working.
        self.return_forward_prob = return_all_probs

        self.erfc = P.Erfc()
        self.grad = C.GradOperation()
        self.concat = P.Concat(axis=-1)

    def construct(self, R, V):
        """Return ptp [B,1] (and pf, pb when the flag is set)."""
        # [B,A,3] -> [B,1]: scalar reaction coordinate.
        r = self.coor_net(R)
        # [B,A,3]: gradient of the reaction coordinate w.r.t. coordinates.
        dr = self.grad(self.coor_net)(R)
        # [B,1]: velocity projected onto the reaction-coordinate gradient.
        batch_size = R.shape[0]
        v = C.batch_dot(dr.reshape(batch_size, -1), V.reshape(batch_size, -1))

        # [B,2]: forward uses +v, backward uses -v.
        qf = self.concat((r, v))
        qb = self.concat((r, -v))

        # [B,1]: 0.5*erfc(-logit) maps logits into (0,1).
        pf = 0.5 * self.erfc(-self.prob_net(qf))
        pb = 0.5 * self.erfc(-self.prob_net(qb))

        # [B,1]: pf*(1-pb) + (1-pf)*pb.
        ptp = pf * (1 - pb) + (1.0 - pf) * pb

        # Bug fix: the original read `self.return_all_probs`, an attribute
        # that was never set (__init__ stores `self.return_forward_prob`),
        # so construct raised AttributeError at runtime.
        if self.return_forward_prob:
            return ptp, pf, pb
        else:
            return ptp

class AlphaChemlLossCellWithMemorySet2(nn.Cell):
    """AlphaChem memory-set loss built from pre-wrapped PTP networks.

    Unlike AlphaChemlLossCellWithMemorySet, the caller supplies both the
    trainable transition-path-probability network (`ptp_net`, called as
    `ptp_net(R, V)` returning `(ptp, pf, pb)`) and its memory counterpart
    (`ptp_mem`, returning `ptp` only), which is EMA-updated from `ptp_net`.

    Args:
        ptp_net (nn.Cell):       Trainable PTP network.
        ptp_mem (nn.Cell):       Memory PTP network (EMA target).
        batch_size (int):        Batch size B.
        num_contrast (int):      Number of contrastive samples C.
        weight_positive (float): Weight of the positive-sample likelihood term.
        eps (float):             Lower clip bound used inside logarithms.
        momentum_decay (float):  EMA rate d; memory <- (1-d)*memory + d*net.
        memset_pos_init (Tensor, optional): [C,B,1] initial positive memory set.
        memset_neg_init (Tensor, optional): [C,B,1] initial negative memory set.

    Raises:
        ValueError: if an initial memory set has the wrong shape.
    """

    def __init__(self,
        ptp_net,
        ptp_mem,
        batch_size,
        num_contrast,
        weight_positive=0.4,
        eps=1e-4,
        momentum_decay=1e-3,
        memset_pos_init=None,
        memset_neg_init=None,
    ):
        super().__init__(auto_prefix=False)
        self.ptp_net = ptp_net
        self.ptp_mem = ptp_mem

        self.momentum_decay = momentum_decay
        self.momentum_coef = 1.0 - momentum_decay

        self.network_params = self.ptp_net.trainable_params()
        self.memset_params = self.ptp_mem.trainable_params()

        # Memory parameters are updated by EMA only, never by gradients.
        for p in self.memset_params:
            p.requires_grad = False

        # Debug: list the parameter pairing between network and memory.
        for i, p in enumerate(self.network_params):
            print(i, p)

        for i, p in enumerate(self.memset_params):
            print(i, p)

        memset_shape = (num_contrast, batch_size, 1)
        # [C,B,1] rolling buffers of transition-path probabilities.
        if memset_pos_init is None:
            self.memset_pos = ms.Parameter(msnp.ones((num_contrast, batch_size, 1), ms.float32), name='memset_pos', requires_grad=False)
        elif memset_pos_init.shape == memset_shape:
            self.memset_pos = ms.Parameter(memset_pos_init, 'memset_pos', requires_grad=False)
        else:
            raise ValueError('The shape of memset_pos_init should be '+str(memset_shape)+' but get '+str(memset_pos_init.shape))

        if memset_neg_init is None:
            self.memset_neg = ms.Parameter(msnp.zeros((num_contrast, batch_size, 1), ms.float32), name='memset_neg', requires_grad=False)
        elif memset_neg_init.shape == memset_shape:
            self.memset_neg = ms.Parameter(memset_neg_init, 'memset_neg', requires_grad=False)
        else:
            raise ValueError('The shape of memset_neg_init should be '+str(memset_shape)+' but get '+str(memset_neg_init.shape))

        self.assign = P.Assign()

        self.batch_size = batch_size
        self.num_contrast = num_contrast

        self.weight_positive = weight_positive
        self.eps = eps

        self.log = P.Log()

        self.keep_sum = P.ReduceSum(keep_dims=True)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.keep_mean = P.ReduceMean(keep_dims=True)
        self.reduce_mean = P.ReduceMean(keep_dims=False)

        self.scalar_summary = P.ScalarSummary()
        self.tensor_summary = P.TensorSummary()

        # Rolling write index into the memory sets (modulo num_contrast).
        self.count = ms.Parameter(Tensor(0, ms.int32), name='count', requires_grad=False)
        self.one = Tensor(1, ms.float32)

        self.print = P.Print()
        self.mod = P.Mod()

    def _update_memset_networks(self):
        """EMA-update the memory parameters: m <- (1-d)*m + d*n.

        Bug fix: the original `m = self.momentum_coef * m + ...` only rebound
        the Python loop variable and never wrote the result back, so the
        memory network stayed frozen at its initial state.  Assign performs
        the in-place parameter update.
        """
        for n, m in zip(self.network_params, self.memset_params):
            self.assign(m, self.momentum_coef * m + self.momentum_decay * n)
        return self

    def _calc_likeli(self, pf, pb, Mf, Mb):
        """Mean negative log-likelihood of shooting outcomes (fwd+bwd)/2."""
        # -log(1-pf) where Mf is True, else -log(pf); clip keeps log finite.
        log_1_pf = -self.log(C.clip_by_value(1 - pf, self.eps, 1))
        log_pf = -self.log(C.clip_by_value(pf, self.eps, 1))
        likeli_forw = F.select(Mf, log_1_pf, log_pf)
        likeli_forw = self.reduce_mean(likeli_forw)

        log_1_pb = -self.log(C.clip_by_value(1 - pb, self.eps, 1))
        log_pb = -self.log(C.clip_by_value(pb, self.eps, 1))
        likeli_back = F.select(Mb, log_1_pb, log_pb)
        likeli_back = self.reduce_mean(likeli_back)

        return (likeli_forw + likeli_back) / 2

    def construct(self, Rp, Vp, Mp, Ip, Rn, Vn, Mn, In):
        """Compute the loss function of the AlphaChem.

        Args:
            Rp     (mindspore.Tensor[float], [B*C, A, 3]):    Cartesian coordinates for each atom of positive samples.
            Vp     (mindspore.Tensor[float], [B*C, A, 3]):    Velocities for each atom of positive sample.
            Mp     (mindspore.Tensor[bool],  [B*C, 1]):       Mask for shooting results (True for A and False for B) of positive sample.
            Ip     (mindspore.Tensor[int]):                   Indices used to gather the averaged positive memory set (presumably batch indices — TODO confirm against the data pipeline).
            Rn     (mindspore.Tensor[float], [B*C, A, 3]):    Cartesian coordinates for each atom of negative sample.
            Vn     (mindspore.Tensor[float], [B*C, A, 3]):    Velocities for each atom of negative sample.
            Mn     (mindspore.Tensor[bool],  [B*C, 1]):       Mask for shooting results (True for A and False for B) of negative sample.
            In     (mindspore.Tensor[int]):                   Indices used to gather the averaged negative memory set.

            B:  Batch size
            C:  Number of contrastive samples
            A:  Number of input atoms

        Returns:
            loss (mindspore.Tensor[float], scalar): weighted sum of the
            contrastive loss and the shooting-likelihood loss.
        """

        # Advance the rolling write index modulo C.  Assign returns the
        # stored value; the `+= 0` appears to be a graph-mode trick to force
        # the assignment to execute before `count` is used as an index.
        count = self.mod(self.count + 1, self.num_contrast)
        count = self.assign(self.count, count)
        count += 0

        # EMA update of the memory network.
        self._update_memset_networks()

        # [B,1] each.
        ptp_pos, pf_pos, pb_pos = self.ptp_net(Rp, Vp)
        ptp_neg, pf_neg, pb_neg = self.ptp_net(Rn, Vn)

        # Positive samples: backward mask is the negation of the forward
        # mask; negative samples: both ends share Mn.
        likeli_pos = self._calc_likeli(pf_pos, pb_pos, Mp, F.logical_not(Mp))
        likeli_neg = self._calc_likeli(pf_neg, pb_neg, Mn, Mn)

        # [B,1]: memory-network PTPs for the current batch.
        mem_pos = self.ptp_mem(Rp, Vp)
        mem_neg = self.ptp_mem(Rn, Vn)

        # Overwrite one slot of each [C,B,1] memory set.
        self.memset_pos[count] = mem_pos
        self.memset_neg[count] = mem_neg

        # [C,B,1] -> [B,1]: average over the contrastive axis.
        sum_pos = self.reduce_mean(self.memset_pos, 0)
        sum_neg = self.reduce_mean(self.memset_neg, 0)

        sum_pos = sum_pos[Ip]
        sum_neg = sum_neg[In]

        # [B,1]: InfoNCE-style ratios, clipped before the log.
        loss_pos = -self.log(C.clip_by_value(ptp_pos / (ptp_pos + sum_neg), self.eps, 1))
        loss_neg = -self.log(C.clip_by_value((1. - ptp_neg) / ((1. - ptp_neg) + (1. - sum_pos)), self.eps, 1))

        loss_pos = self.reduce_mean(loss_pos)
        loss_neg = self.reduce_mean(loss_neg)

        loss_contrast = 0.5 * (loss_pos + loss_neg)

        loss_likelihood = self.weight_positive * likeli_pos + likeli_neg

        self.scalar_summary('loss_contrast', loss_contrast)
        self.scalar_summary('loss_likelihood', loss_likelihood)

        loss = loss_contrast + loss_likelihood * 0.2

        return loss
